{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 构建语音识别系统 - 解码与评测"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Help1： 本节课的内容主要为如何将模型的输出转化到识别结果（通过CTC解码和词典实现）和计算CER指标。**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 如何从模型输出到识别文本"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Help2： 模型的输出维度需要让大家了解一下，notebook中的 (b, len, vocab_size)这三位代表的是什么。（开始训练模型的同学可能已经知道了，但这里要强调一下），分别为批大小，批次中序列最大长度（tensor需要保证一批数据中的长度一致，长度不够的要经过padding，因此我们的dataloader里面会有一个变量存储每个音频的长度），字典大小**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "模型输出的结果为 (b, len, vocab_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Help3： 贪婪搜索是对每一帧的预测选取概率最大的索引作为该帧的识别结果。对应的是notebook里面的greedy_search函数，这个函数是将模型的输出结果进行处理，对每一帧选取概率最大的位置，转换为字典对应的字符id。**\n",
    "\n",
    "**该函数的输入为模型输出结果和该batch中每个音频的帧的长度。返回为一个len为batchsize的列表，每个元素为识别结果（字典中对应字符的id）的列表**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### greedy search\n",
    "\n",
    "每一步选取预测概率最大的词"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from tokenizer.tokenizer import Tokenizer\n",
    "tokenizer = Tokenizer()\n",
    "\n",
    "def greedy_search(ctc_probs: torch.Tensor, encoder_out_lens: torch.Tensor):\n",
    "    batch_size, maxlen = ctc_probs.size()[:2]\n",
    "    _, topk_index = ctc_probs.topk(1, dim=2)  # [batch, seq_len, 1]\n",
    "    topk_index = topk_index.squeeze(-1)  # [batch, seq_len]\n",
    "    encoder_out_lens = encoder_out_lens.cpu().tolist()\n",
    "\n",
    "    hyps = []\n",
    "    for i in range(batch_size):\n",
    "        # 获取有效长度内的序列\n",
    "        seq = topk_index[i, :encoder_out_lens[i]].cpu().tolist()\n",
    "        \n",
    "        # 移除重复字符和blank\n",
    "        prev = None\n",
    "        hyp = []\n",
    "        for char in seq:\n",
    "            if char == tokenizer.blk_id():  # 过滤blank\n",
    "                prev = None\n",
    "                continue\n",
    "            if char != prev:  # 去重复\n",
    "                hyp.append(char)\n",
    "                prev = char\n",
    "        hyps.append(hyp)\n",
    "    \n",
    "    return hyps"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Help4： example2.pt是我本地训练好的模型的一个输出样例，通过这个样例来帮助大家了解这个流程**\n",
    "\n",
    "**下面的cell里面，打印的是该batch中的每个音频的长度**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([46, 51, 44, 44, 41, 49, 48, 48, 74, 93, 44, 49, 50, 51, 58, 50])\n"
     ]
    }
   ],
   "source": [
    "tensordict = torch.load(\"./example2.pt\")\n",
    "\n",
    "pre = tensordict[\"pre\"].to(\"cpu\")\n",
    "lens = tensordict[\"lens\"].to(\"cpu\")\n",
    "\n",
    "print(lens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = greedy_search(pre, lens)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Help5： 下面展示了经过贪婪搜索以后，对该batch的第一个音频的处理结果。**\n",
    "\n",
    "**可以看到这里有一些特殊字符，这里需要大家做到下面的工作： 1.大家在上一节课了解了CTC的解码过程，请大家根据自己的理解来对这个输出进行处理；2.把特殊字符处理，例如\\<sos\\>和\\<eos\\>等**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2, 40, 188, 227, 247, 243, 375, 360, 32, 87, 251, 291, 282, 32, 141, 243, 55, 317, 3]\n",
      "['<sos>', 'chen', 'pin', 'mao', 'hen', 'si', 'chi', 'zong', 'tiao', 'lian', 'jie', 'wei', 'pen', 'tiao', 'luan', 'si', 'zhua', 'nie', '<eos>']\n"
     ]
    }
   ],
   "source": [
    "from tokenizer.tokenizer import Tokenizer\n",
    "tokenizer = Tokenizer()\n",
    "\n",
    "print(res[0])\n",
    "\n",
    "print(tokenizer.decode(res[0], ignore_special=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**TODO：** 请大家根据CTC的解码思路，将模型的输出进行解码，移除重复字符和blank(上面的版本没有移除重复字符和blank)。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 评测识别结果"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Help6： 计算CER指标。这个指标的计算涉及最小编辑距离（同学们应该在之前的算法课学习动态规划的时候实现过）。大家求S，D，I的时候，可以调Python包实现，也可以自己实现。最后计算总的CER的时候，请使用：$CER_{final} = \\frac{S_{total}+D_{total}+I_{total}}{N_{total}}$，也就是大家需要计算总的S，D，I，N来求CER指标。**\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这里我们采用字错率(CER, character error rate)来评测ASR系统的性能，计算公式如下:\n",
    "\n",
    "$$CER = \\frac{S+D+I}{N}$$\n",
    "\n",
    "pre 代表模型预测， gt 代表正确识别结果。与最小编辑距离一致，将pre转化成gt，其中，S代表将 pre 转化成 gt 需要替换的数量，D 代表将 pre转化成 gt 需要删除的数量，I 代表将 pre 转化成 gt 需要插入的数量，N 代表gt 的长度。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**TODO：** 根据最小编辑距离求出 S，D，I，N ，完成ASR的CER指标评测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_cer(pre_tokens: list, gt_tokens: list) -> tuple:\n",
    "    m, n = len(pre_tokens), len(gt_tokens)\n",
    "    \n",
    "    # 初始化动态规划表\n",
    "    dp = [[0] * (n+1) for _ in range(m+1)]\n",
    "    for i in range(m+1):\n",
    "        dp[i][0] = i  # 删除所有pre字符\n",
    "    for j in range(n+1):\n",
    "        dp[0][j] = j  # 插入所有gt字符\n",
    "        \n",
    "    # 填充DP表\n",
    "    for i in range(1, m+1):\n",
    "        for j in range(1, n+1):\n",
    "            if pre_tokens[i-1] == gt_tokens[j-1]:\n",
    "                cost = 0\n",
    "            else:\n",
    "                cost = 1\n",
    "            dp[i][j] = min(\n",
    "                dp[i-1][j] + 1,    # 删除操作\n",
    "                dp[i][j-1] + 1,    # 插入操作\n",
    "                dp[i-1][j-1] + cost # 替换或匹配\n",
    "            )\n",
    "    \n",
    "    # 回溯统计S/D/I\n",
    "    i, j = m, n\n",
    "    S = D = I = 0\n",
    "    \n",
    "    while i > 0 and j > 0:\n",
    "        if pre_tokens[i-1] == gt_tokens[j-1]:\n",
    "            # 匹配，无需操作\n",
    "            i -= 1\n",
    "            j -= 1\n",
    "        else:\n",
    "            # 优先选择操作次数最小的路径\n",
    "            if dp[i][j] == dp[i-1][j-1] + 1:\n",
    "                # 替换操作\n",
    "                S += 1\n",
    "                i -= 1\n",
    "                j -= 1\n",
    "            elif dp[i][j] == dp[i-1][j] + 1:\n",
    "                # 删除操作\n",
    "                D += 1\n",
    "                i -= 1\n",
    "            else:\n",
    "                # 插入操作\n",
    "                I += 1\n",
    "                j -= 1\n",
    "    \n",
    "    # 处理剩余字符\n",
    "    D += i  # 剩余pre字符需删除\n",
    "    I += j  # 剩余gt字符需插入\n",
    "    \n",
    "    N = len(gt_tokens)\n",
    "    cer = (S + D + I) / N if N != 0 else 0.0\n",
    "    return cer, S, D, I, N"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CTCModel(\n",
       "  (subsampling): Subsampling(\n",
       "    (subsampling): Conv2dSubsampling8(\n",
       "      (conv): Sequential(\n",
       "        (0): Conv2d(1, 256, kernel_size=(3, 3), stride=(2, 2))\n",
       "        (1): ReLU()\n",
       "        (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2))\n",
       "        (3): ReLU()\n",
       "        (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2))\n",
       "        (5): ReLU()\n",
       "      )\n",
       "      (linear): Linear(in_features=2304, out_features=256, bias=True)\n",
       "    )\n",
       "  )\n",
       "  (positional_encoding): RelPositionalEncoding(\n",
       "    (dropout): Dropout(p=0.1, inplace=False)\n",
       "  )\n",
       "  (encoder): ConformerEncoder(\n",
       "    (layers): ModuleList(\n",
       "      (0-2): 3 x ConformerBlock(\n",
       "        (ff1_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "        (ff1_linear1): Linear(in_features=256, out_features=1024, bias=True)\n",
       "        (ff1_activation): SiLU()\n",
       "        (ff1_dropout1): Dropout(p=0.1, inplace=False)\n",
       "        (ff1_linear2): Linear(in_features=1024, out_features=256, bias=True)\n",
       "        (ff1_dropout2): Dropout(p=0.1, inplace=False)\n",
       "        (attn): MultiheadAttention(\n",
       "          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)\n",
       "        )\n",
       "        (attn_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "        (conv): ConformerConvModule(\n",
       "          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "          (conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n",
       "          (glu): GLU(dim=1)\n",
       "          (depthwise_conv): Conv1d(256, 256, kernel_size=(31,), stride=(1,), padding=(15,), groups=256)\n",
       "          (batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (activation): SiLU()\n",
       "          (conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ff2_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "        (ff2_linear1): Linear(in_features=256, out_features=1024, bias=True)\n",
       "        (ff2_activation): SiLU()\n",
       "        (ff2_dropout1): Dropout(p=0.1, inplace=False)\n",
       "        (ff2_linear2): Linear(in_features=1024, out_features=256, bias=True)\n",
       "        (ff2_dropout2): Dropout(p=0.1, inplace=False)\n",
       "        (final_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (fc_out): Linear(in_features=256, out_features=412, bias=True)\n",
       "  (ctc_loss): CTCLoss()\n",
       ")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from data.dataloader import get_dataloader\n",
    "from model.model import CTCModel\n",
    "import torch\n",
    "from utils.utils import to_device\n",
    "from tqdm import tqdm\n",
    "\n",
    "dev_dataloader = get_dataloader(\"./dataset/split/dev/wav.scp\", \"./dataset/split/dev/pinyin\", 32, tokenizer, shuffle=False)\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "model = CTCModel(80, 256, tokenizer.size(), tokenizer.blk_id()).to(device)\n",
    "\n",
    "checkpoint = torch.load(\"./model.pt\", map_location=device)\n",
    "model.load_state_dict(checkpoint['model'])\n",
    "\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_model(dataloader, model, tokenizer, device='cpu'):\n",
    "    all_refs = []\n",
    "    all_hyps = []\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm(dataloader, desc=\"评估中\"):\n",
    "            batch = to_device(batch, device)\n",
    "            audios = batch['audios']\n",
    "            audio_lens = batch['audio_lens']\n",
    "            texts = batch['texts']\n",
    "            text_lens = batch['text_lens']\n",
    "            \n",
    "            encoder_out, _, encoder_out_lens = model(audios, audio_lens, texts, text_lens)\n",
    "            \n",
    "            hyps = greedy_search(encoder_out, encoder_out_lens)\n",
    "            \n",
    "            for i in range(len(text_lens)):\n",
    "                ref = texts[i, :text_lens[i]].cpu().tolist()\n",
    "                all_refs.append(ref)\n",
    "                all_hyps.append(hyps[i])\n",
    "    \n",
    "    total_S = total_D = total_I = total_N = 0\n",
    "    for ref, hyp in zip(all_refs, all_hyps):\n",
    "        cer, S, D, I, N = calculate_cer(hyp, ref)\n",
    "        total_S += S\n",
    "        total_D += D\n",
    "        total_I += I\n",
    "        total_N += N\n",
    "    \n",
    "    final_cer = (total_S + total_D + total_I) / total_N if total_N > 0 else 1.0\n",
    "    \n",
    "    print(f\"评估结果:\")\n",
    "    print(f\"替换(S): {total_S}, 删除(D): {total_D}, 插入(I): {total_I}, 参考长度(N): {total_N}\")\n",
    "    print(f\"CER: {final_cer:.4f} ({total_S+total_D+total_I}/{total_N})\")\n",
    "\n",
    "    print(\"\\n样本对比:\")\n",
    "    for i in range(min(5, len(all_refs))):\n",
    "        print(f\"参考: {tokenizer.decode(all_refs[i])}\")\n",
    "        print(f\"预测: {tokenizer.decode(all_hyps[i])}\")\n",
    "        print()\n",
    "    \n",
    "    return final_cer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "评估中: 100%|██████████| 32/32 [00:04<00:00,  6.51it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "评估结果:\n",
      "替换(S): 825, 删除(D): 46, 插入(I): 85, 参考长度(N): 19252\n",
      "CER: 0.0497 (956/19252)\n",
      "\n",
      "样本对比:\n",
      "参考: ['yi', 'ge', 'nan', 'ren', 'tui', 'ran', 'de', 'zuo', 'zai', 'pang', 'bian', 'mu', 'guang', 'dai', 'zhi']\n",
      "预测: ['yi', 'gen', 'nai', 'ren', 'tui', 'ran', 'de', 'zuo', 'zai', 'pang', 'bian', 'mu', 'guan', 'dai', 'zhi']\n",
      "\n",
      "参考: ['xi', 'huan', 'ba', 'li', 'ao', 'de', 'shu', 'cha', 'zai', 'niu', 'zai', 'ku', 'de', 'qian', 'mian']\n",
      "预测: ['xi', 'huan', 'ba', 'li', 'ao', 'de', 'shu', 'cha', 'zai', 'niu', 'zai', 'ku', 'de', 'qian', 'mian']\n",
      "\n",
      "参考: ['zha', 'yan', 'yi', 'kan', 'xiang', 'qi', 'de', 'shi', 'guang', 'zhou', 'de', 'qu', 'hao', 'ling', 'e', 'er', 'ling']\n",
      "预测: ['zhe', 'ye', 'yi', 'kan', 'xiang', 'qi', 'de', 'shi', 'guang', 'zhou', 'de', 'xu', 'hao', 'liu', 'e', 'er', 'ling']\n",
      "\n",
      "参考: ['ci', 'qian', 'qing', 'hua', 'zi', 'guang', 'jiu', 'ceng', 'zao', 'yu', 'guo', 'tong', 'yang', 'de', 'wei', 'ji']\n",
      "预测: ['ci', 'qian', 'qing', 'hua', 'zi', 'guang', 'jiu', 'ceng', 'zao', 'yu', 'guo', 'tong', 'yang', 'de', 'wei', 'ji']\n",
      "\n",
      "参考: ['bu', 'tai', 'ming', 'bai', 'shen', 'me', 'yi', 'si']\n",
      "预测: ['bu', 'tai', 'ming', 'bai', 'shen', 'men', 'yin', 'si']\n",
      "\n",
      "最终CER: 0.0497\n"
     ]
    }
   ],
   "source": [
    "cer = evaluate_model(dev_dataloader, model, tokenizer, device)\n",
    "print(f\"最终CER: {cer:.4f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
